Progress Report

Shusei Eshima

2018-05-14

library(topicdict)
library(purrr)
library(quanteda)
library(tibble)
library(ggplot2)
library(dplyr)
library(topicmodels)

Prepare Functions

Create model function:

# Build a topicdict model whose seeded topics come from `seed_list`.
#
# Args:
#   docs:      character vector of document file paths, passed to
#              topicdict_model().
#   seed_list: list of character vectors, one element per seeded topic.
#              Elements are renamed "1", "2", ... before the quanteda
#              dictionary is built.
#   extra_k:   number of additional (unseeded) topics.
#   seed:      RNG seed set before the model is built (default 225, the
#              value used throughout this report).
#
# Returns the object produced by topicdict_model().
create_model <- function(docs, seed_list, extra_k, seed = 225) {
  set.seed(seed)
  # seq_along() (not 1:length()) stays correct for an empty list.
  names(seed_list) <- seq_along(seed_list)
  dict <- quanteda::dictionary(seed_list)
  topicdict_model(docs,
                  dict = dict, extra_k = extra_k,
                  remove_numbers = FALSE,
                  remove_punct = TRUE,
                  remove_symbols = TRUE,
                  remove_separators = TRUE)
}
# Check dispersion: tidy the top words of a fitted seeded LDA.
#
# Args:
#   model: topicdict model (e.g. from create_model()); only `model$seeds` is
#          read, and only when `show = TRUE`.
#   res:   trained result passed to topicdict::posterior().
#   n:     number of top words per estimated topic.
#   show:  if TRUE, print the top-word table for the seeded topics.
#
# Returns a data frame with columns EstTopic ("EstTopic1", ...) and Word, one
# row per (topic, top word). Anything after the first whitespace in a
# top_terms() entry (e.g. a " [✓]" marker) is stripped from Word.
tidy_seededlda_out <- function(model, res, n = 15, show = FALSE) {
  post <- topicdict::posterior(res)
  # Compute the top-word table once; it is reused for printing below.
  top_mat <- top_terms(post, n = n)

  topwords <- data.frame(top_mat)
  colnames(topwords) <- paste0("EstTopic", seq_len(ncol(topwords)))
  otidy <- topwords %>%
    tidyr::gather(key = EstTopic, value = Word) %>%
    mutate(Word = gsub("\\s.*$", "", Word))

  if (show) {
    num_seededtopic <- length(model$seeds)
    # seq_len() + drop = FALSE: no spurious column when there are no seeds,
    # and still a matrix when there is exactly one seeded topic.
    print(top_mat[, seq_len(num_seededtopic), drop = FALSE])
  }

  otidy
}


# Flatten a list of keyword vectors into a long tibble.
#
# Args:
#   lobj: list of character vectors, one element per seeded topic.
#
# Returns a tibble with columns SeedTopic ("EstTopic<i>" for the i-th list
# element, repeated once per keyword) and Word (all keywords, in order).
# An empty list yields a zero-row tibble instead of an error.
list_to_tibble <- function(lobj) {
  obj_len <- lengths(lobj)
  element <- as.character(unlist(lobj, use.names = FALSE))

  tibble(SeedTopic = rep(paste0("EstTopic", seq_along(obj_len)), obj_len),
         Word = element)
}

# Count, for each seed keyword, how many estimated topics list it among their
# top words, and plot the distribution of those counts.
#
# Args:
#   otidy: data frame with columns EstTopic and Word (e.g. the output of
#          tidy_seededlda_out()).
#   lobj:  list of seed keyword vectors, one element per seeded topic.
#
# Returns a list of three ggplot objects:
#   [[1]] bar chart of counts over all seed keywords;
#   [[2]] the same, faceted by seed topic;
#   [[3]] faceted counts using only the seeded estimated topics
#         ("EstTopic1" .. "EstTopic<length(lobj)>").
count_appearence_list <- function(otidy, lobj) {
  seed_words <- list_to_tibble(lobj)
  seed_topic_names <- paste0("EstTopic", seq_along(lobj))

  # For each (Word, SeedTopic), the number of rows of `tidy_df` containing
  # the word; the right join keeps words found nowhere, with count 0.
  count_by_word <- function(tidy_df) {
    tidy_df %>%
      right_join(seed_words, by = "Word") %>%
      mutate(count = ifelse(is.na(EstTopic), 0, 1)) %>%
      group_by(Word, SeedTopic) %>%
      summarize(count = sum(count)) %>%
      ungroup()
  }

  # Bar chart of how many topics each keyword was split across.
  plot_counts <- function(organized, title, facet = FALSE) {
    g <- ggplot(organized, aes(x = factor(count))) +
      geom_bar()
    if (facet) {
      g <- g + facet_wrap(~ SeedTopic, ncol = 3)
    }
    g +
      xlab("Number of Topics") + ylab("Count") + ggtitle(title) +
      theme_bw(base_size = 15) +
      theme(plot.title = element_text(hjust = 0.5))
  }

  organized <- count_by_word(otidy)
  # Only check topics with seed
  organized_seeded <- count_by_word(
    filter(otidy, EstTopic %in% seed_topic_names)
  )

  g1 <- plot_counts(organized, "Words Split across Topics")
  g2 <- plot_counts(organized, "Words Split across Topics", facet = TRUE)
  g3 <- plot_counts(organized_seeded,
                    "Words Split across Topics\n(Only topics with keywords)",
                    facet = TRUE)

  list(g1, g2, g3)
}
library(tm)

# Fit a standard (unseeded) Gibbs-sampled LDA with topicmodels and plot, for
# each seed keyword, how many estimated topics contain it among their top
# `show_n` words.
#
# Args:
#   doc_folder: directory of plain-text documents read by tm::DirSource().
#   seed_list:  list of seed keyword vectors (used only for evaluation).
#   iter_num:   number of Gibbs iterations.
#   k:          number of topics.
#   topicvec:   unused; kept so existing calls keep working.
#   show_n:     number of top words (by beta) per topic; top_n() may keep
#               more than `show_n` words when betas tie.
#
# Returns a list of two ggplot objects: the overall bar chart of counts and
# the same chart faceted by seed topic.
get_lda_result <- function(doc_folder, seed_list, iter_num, k, topicvec = 1:k,
                           show_n = 15) {
  # Prepare data: whitespace tokenisation and lower-casing only (no stop-word
  # removal or stemming; words of any length are kept).
  corpus <- Corpus(DirSource(doc_folder))
  strsplit_space_tokenizer <- function(x) {
    unlist(strsplit(as.character(x), "[[:space:]]+"))
  }
  dtm <- DocumentTermMatrix(
    corpus,
    control = list(tokenize = strsplit_space_tokenizer,
                   stopwords = FALSE, tolower = TRUE,
                   stemming = FALSE, wordLengths = c(1, Inf))
  )

  lda <- LDA(dtm, k = k, control = list(seed = 225, iter = iter_num),
             method = "Gibbs")

  # For every seed keyword, count the topics whose top words contain it;
  # keywords absent from every topic get a count of 0.
  organized <- tidytext::tidy(lda) %>%
    group_by(topic) %>%
    top_n(show_n, beta) %>%
    ungroup() %>%
    rename(topics = topic, Word = term) %>%
    select(-beta) %>%
    right_join(list_to_tibble(seed_list), by = "Word") %>%
    mutate(count = ifelse(is.na(topics), 0, 1)) %>%
    group_by(Word, SeedTopic) %>%
    summarize(count = sum(count)) %>%
    ungroup()

  # Bar chart of how many topics each keyword was split across.
  plot_counts <- function(facet) {
    g <- ggplot(organized, aes(x = factor(count))) +
      geom_bar()
    if (facet) {
      g <- g + facet_wrap(~ SeedTopic, ncol = 3)
    }
    g +
      xlab("Number of Topics") + ylab("Count") +
      ggtitle("Words Split across Topics") +
      theme_bw(base_size = 15) +
      theme(plot.title = element_text(hjust = 0.5))
  }

  list(plot_counts(facet = FALSE), plot_counts(facet = TRUE))
}

Check dispersion with simulation data

True K = 15

data_folder <- tempfile()
seed_list <- create_sim_data(saveDir=paste0(data_folder, "Sim1"), D=1000, K=15, TotalV=3000, alpha=0.1, beta_r=0.1, beta_s=0.1, p=c(rep(0.2, 5),rep(0.12, 5),rep(0.05, 5)), lambda=200, seeds_len=5)
[1] "Finished: "
   user  system elapsed 
  9.428   0.943  10.735 
# Lower-case the simulated keywords and seed only topics 3, 4, 7, 8, 14, 15.
seed_list <- lapply(seed_list, tolower)
seed_list_full <- seed_list
seed_list <- seed_list_full[c(3,4,7,8,14,15)]
doc_folder <- paste0(data_folder, "Sim1", "/W")
# `pattern` is a regular expression, not a glob: match names ending in ".txt".
docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
explore_ <- explore(docs,
             remove_numbers = FALSE, # For simulation, make it false
             remove_punct = TRUE,
             remove_symbols = TRUE,
             remove_separators = TRUE)
explore_$visualize_dict_prop(seed_list)

K = 15

model <- create_model(docs, seed_list, extra_k=9)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=15)
[[1]]


[[2]]

K = 25

model <- create_model(docs, seed_list, extra_k=19)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=25)
[[1]]


[[2]]

K = 50

model <- create_model(docs, seed_list, extra_k=44)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=50)
[[1]]


[[2]]

K = 100

model <- create_model(docs, seed_list, extra_k=94)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=100)
[[1]]


[[2]]

K = 25, only middle proportion

seed_list <- seed_list_full[c(3,4,5)]
explore_$visualize_dict_prop(seed_list)

model <- create_model(docs, seed_list, extra_k=22)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=25)
[[1]]


[[2]]

K = 25, only low proportion

seed_list <- seed_list_full[c(11,14,15)]
explore_$visualize_dict_prop(seed_list)

model <- create_model(docs, seed_list, extra_k=22)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=25)
[[1]]


[[2]]

K = 25, No Contamination

seed_list <- seed_list_full[c(6,5,7,9,12)]
explore_$visualize_dict_prop(seed_list)

model <- create_model(docs, seed_list, extra_k=20)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=25)
[[1]]


[[2]]

Contamination 1, K = 25

seed_list <- list(
                  c(seed_list_full[[6]][1:4],seed_list_full[[13]][2]),
                  c(seed_list_full[[5]][1:4],seed_list_full[[7]][2]),
                  c(seed_list_full[[7]][1:4], seed_list_full[[3]][2]),
                  c(seed_list_full[[9]][1:4], seed_list_full[[5]][2]),
                  c(seed_list_full[[12]][1:4], seed_list_full[[8]][2])
                  ) 
explore_$visualize_dict_prop(seed_list)

model <- create_model(docs, seed_list, extra_k=20)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=25)
[[1]]


[[2]]

Contamination 2, K = 25

# Contamination 2: the same base topics as Contamination 1 (6, 5, 7, 9, 12),
# keeping 5, 4, 3, 3 and 3 of their own keywords and filling the remaining
# slots with keywords from other topics. Row 3 uses base topic 7 so that the
# contaminating word seed_list_full[[3]][2] is not also one of its own words.
seed_list <- list(
                  c(seed_list_full[[6]][1:5]),
                  c(seed_list_full[[5]][1:4],seed_list_full[[7]][2]),
                  c(seed_list_full[[7]][1:3], seed_list_full[[3]][2], seed_list_full[[5]][2]),
                  c(seed_list_full[[9]][1:3], seed_list_full[[5]][3], seed_list_full[[7]][3]),
                  c(seed_list_full[[12]][1:3], seed_list_full[[8]][2], seed_list_full[[6]][2])
                  ) 
explore_$visualize_dict_prop(seed_list)

model <- create_model(docs, seed_list, extra_k=20)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

True K = 45

data_folder <- tempfile()
seed_list <- create_sim_data(saveDir=paste0(data_folder, "Sim1"), D=1000, K=45, TotalV=3000, alpha=0.1, beta_r=0.1, beta_s=0.1, p=c(rep(0.2, 15),rep(0.12, 15),rep(0.05, 15)), lambda=200, seeds_len=5)
[1] "Finished: "
   user  system elapsed 
  7.968   1.137   9.199 
# Lower-case the simulated keywords and seed only six of the 45 topics.
seed_list <- lapply(seed_list, tolower)
seed_list_full <- seed_list
seed_list <- seed_list_full[c(5,12,19,24,36,40)]
doc_folder <- paste0(data_folder, "Sim1", "/W")
# `pattern` is a regular expression, not a glob: match names ending in ".txt".
docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
explore_ <- explore(docs,
             remove_numbers = FALSE, # For simulation, make it false
             remove_punct = TRUE,
             remove_symbols = TRUE,
             remove_separators = TRUE)
explore_$visualize_dict_prop(seed_list)

K = 50

model <- create_model(docs, seed_list, extra_k=44)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=50)
[[1]]


[[2]]

K = 80

model <- create_model(docs, seed_list, extra_k=74)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=80)
[[1]]


[[2]]

K = 100

model <- create_model(docs, seed_list, extra_k=94)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=100)
[[1]]


[[2]]

Contamination 1, K = 60

seed_list <- list(
                  c(seed_list_full[[5]][1:4],seed_list_full[[45]][2]),
                  c(seed_list_full[[12]][1:4],seed_list_full[[33]][2]),
                  c(seed_list_full[[19]][1:4], seed_list_full[[16]][2]),
                  c(seed_list_full[[24]][1:4], seed_list_full[[21]][2]),
                  c(seed_list_full[[36]][1:4], seed_list_full[[9]][2])
                  ) 
explore_$visualize_dict_prop(seed_list)

model <- create_model(docs, seed_list, extra_k=55)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=60)
[[1]]


[[2]]

Contamination 2, K = 60

seed_list <- list(
                  c(seed_list_full[[5]][1:5]),
                  c(seed_list_full[[12]][1:4],seed_list_full[[9]][2]),
                  c(seed_list_full[[19]][1:3], seed_list_full[[22]][2], seed_list_full[[40]][2]),
                  c(seed_list_full[[24]][1:3], seed_list_full[[33]][3], seed_list_full[[28]][3]),
                  c(seed_list_full[[36]][1:3], seed_list_full[[12]][2], seed_list_full[[16]][2])
                  ) 
explore_$visualize_dict_prop(seed_list)

model <- create_model(docs, seed_list, extra_k=55)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
[[1]]


[[2]]


[[3]]

A Detailed Look at Amy’s Data

Read Data

# Documents built from the original data's document-term matrix.
# NOTE(review): machine-specific absolute path.
doc_folder <- "/Users/Shusei/Dropbox/Study/My_Research/TreeStructuredTopicModel/Papers/replication/Catalinac/data/docs"
# `pattern` is a regular expression, not a glob: match names ending in ".txt".
docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
explore_ <- explore(docs,
             remove_numbers = TRUE,
             remove_punct = TRUE,
             remove_symbols = TRUE,
             remove_separators = TRUE)

Self-Selected Keywords

# Remove overlapping words for better comparison
seed_list <- list(c("農業 整備 漁業 開発 水産"), # agriculture, fishing industry
                  c("税 消費 暮らし 景気 税金"), # tax
                  c("介護 高齢 保険 長寿 健康"), # aging society
                  c("軍事 戦争 軍 自衛隊 平和"))
seed_list <- lapply(seed_list, function(x){strsplit(x, " ")[[1]]})
g <- explore_$visualize_dict_prop(seed_list)
g

ggsave("/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/Catalinac.pdf", g, width=5, height=4, family="Japan1GothicBBB")

K + 1

# Seeded LDA: the four self-selected keyword topics (five keywords each)
# plus 1 extra unseeded topic, i.e. 5 topics in total.
model <- create_model(docs, seed_list, extra_k=1)
res <- topicdict_train(model, iter = iter_num)
# show=T also prints the top words of the seeded topics (table below).
count_appearence_list(tidy_seededlda_out(model, res, show=T), seed_list)
      1          2            3          4         
 [1,] "社会"     "党"         "年金"     "党"      
 [2,] "整備 [✓]" "日本"       "制度"     "政治"    
 [3,] "教育"     "憲法"       "改革"     "日本"    
 [4,] "推進"     "増税"       "地域"     "国民"    
 [5,] "図る"     "政治"       "安心"     "税 [2]"  
 [6,] "作り"     "消費 [✓]"   "医療"     "消費 [2]"
 [7,] "実現"     "守る"       "円"       "共産"    
 [8,] "地域"     "税 [✓]"     "実現"     "自民党"  
 [9,] "充実"     "国民"       "地方"     "企業"    
[10,] "福祉"     "社会"       "日本"     "守る"    
[11,] "対策"     "暮らし [✓]" "介護 [✓]" "反対"    
[12,] "産業"     "自民党"     "社会"     "選挙"    
[13,] "振興"     "円"         "支援"     "平和 [✓]"
[14,] "豊か"     "保障"       "ひと"     "廃止"    
[15,] "政治"     "平和 [4]"   "国"       "民主"    
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=6)
[[1]]


[[2]]

K + 16

# Seeded LDA: the four self-selected keyword topics (five keywords each)
# plus 16 extra unseeded topics, i.e. 20 topics in total.
model <- create_model(docs, seed_list, extra_k=16)
res <- topicdict_train(model, iter = iter_num)
# show=T also prints the top words of the seeded topics (table below).
count_appearence_list(tidy_seededlda_out(model, res, show=T), seed_list)
      1          2            3          4         
 [1,] "整備 [✓]" "消費 [✓]"   "制度"     "党"      
 [2,] "道路"     "税 [✓]"     "年金"     "日本"    
 [3,] "地域"     "憲法"       "医療"     "共産"    
 [4,] "産業"     "党"         "介護 [✓]" "国民"    
 [5,] "振興"     "増税"       "支援"     "政治"    
 [6,] "推進"     "日本"       "実現"     "税 [2]"  
 [7,] "道"       "平和 [4]"   "充実"     "消費 [2]"
 [8,] "促進"     "暮らし [✓]" "雇用"     "増税"    
 [9,] "建設"     "守る"       "負担"     "反対"    
[10,] "県"       "政治"       "安心"     "企業"    
[11,] "交通"     "社会"       "教育"     "民主"    
[12,] "実現"     "アメリカ"   "保険 [✓]" "守る"    
[13,] "図る"     "改悪"       "対策"     "選挙"    
[14,] "早期"     "保障"       "社会"     "つらぬく"
[15,] "都市"     "企業"       "企業"     "基地"    
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=20)
[[1]]


[[2]]

K + 46

# Seeded LDA: the four self-selected keyword topics (five keywords each)
# plus 46 extra unseeded topics, i.e. 50 topics in total.
model <- create_model(docs, seed_list, extra_k=46)
res <- topicdict_train(model, iter = iter_num)
# show=T also prints the top words of the seeded topics (table below).
count_appearence_list(tidy_seededlda_out(model, res, show=T), seed_list)
      1          2            3          4         
 [1,] "整備 [✓]" "党"         "制度"     "憲法"    
 [2,] "振興"     "共産"       "医療"     "消費 [2]"
 [3,] "道路"     "日本"       "年金"     "税 [2]"  
 [4,] "道"       "政治"       "介護 [✓]" "平和 [✓]"
 [5,] "農業 [✓]" "国民"       "支援"     "アメリカ"
 [6,] "地域"     "税 [✓]"     "保険 [✓]" "改悪"    
 [7,] "産業"     "消費 [✓]"   "負担"     "企業"    
 [8,] "図る"     "増税"       "雇用"     "社会"    
 [9,] "農林"     "民主"       "充実"     "保障"    
[10,] "促進"     "暮らし [✓]" "安心"     "守る"    
[11,] "推進"     "自民党"     "実現"     "年金"    
[12,] "県"       "税金 [✓]"   "保育"     "増税"    
[13,] "建設"     "やめる"     "拡充"     "戦争 [✓]"
[14,] "交通"     "選挙"       "費"       "反対"    
[15,] "水産 [✓]" "目指す"     "保障"     "雇用"    
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=50)
[[1]]


[[2]]

Original Keywords

# Remove overlapping words for better comparison
topic52 <- c("農業 産業 整備 漁業 開発")
topic62 <- c("復興 連立 被災 災害 ひと")
topic63 <- c("政治 主義 自由 社会 民主")
topic20 <- c("税 消費 廃止 国民 日本")
topic58 <- c("企業 教育 中小 充実 図る")

seed_list <- list(topic52, topic62, topic63, topic20, topic58)
seed_list <- lapply(seed_list, function(x){strsplit(x, " ")[[1]]})
g <- explore_$visualize_dict_prop(seed_list)
g

ggsave("/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/Catalinac2.pdf", g, width=5, height=4, family="Japan1GothicBBB")

k = 6 (K+1)

Seeded LDA

model <- create_model(docs, seed_list, extra_k=1)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=T), seed_list)
      1          2          3          4          5         
 [1,] "整備 [✓]" "円"       "政治 [✓]" "党"       "教育 [✓]"
 [2,] "政治 [3]" "年金"     "改革"     "日本 [✓]" "税 [4]"  
 [3,] "社会 [3]" "制度"     "国民 [4]" "国民 [✓]" "政治 [3]"
 [4,] "推進"     "医療"     "日本 [4]" "共産"     "福祉"    
 [5,] "地域"     "政権"     "選挙"     "政治 [3]" "守る"    
 [6,] "図る [5]" "兆"       "自民党"   "増税"     "平和"    
 [7,] "作り"     "郵政"     "党"       "税 [✓]"   "実現"    
 [8,] "豊か"     "無駄"     "新しい"   "消費 [✓]" "消費 [4]"
 [9,] "振興"     "民営"     "政権"     "憲法"     "充実 [✓]"
[10,] "実現"     "実現"     "社会 [✓]" "守る"     "円"      
[11,] "産業 [✓]" "廃止 [4]" "実現"     "自民党"   "社会 [3]"
[12,] "福祉"     "金"       "腐敗"     "企業 [5]" "中小 [✓]"
[13,] "発展"     "ひと [✓]" "民主 [✓]" "反対"     "減税"    
[14,] "対策"     "税金"     "ひと [2]" "暮らし"   "企業 [✓]"
[15,] "国際"     "財源"     "世界"     "民主 [3]" "年金"    
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=6)
[[1]]


[[2]]

k = 15 (K+10)

Seeded LDA

model <- create_model(docs, seed_list, extra_k=10)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=T), seed_list)
      1          2          3          4          5         
 [1,] "整備 [✓]" "年金"     "政治 [✓]" "党"       "教育 [✓]"
 [2,] "産業 [✓]" "円"       "自民党"   "日本 [✓]" "福祉"    
 [3,] "推進"     "政権"     "国民 [4]" "消費 [✓]" "図る [✓]"
 [4,] "地域"     "制度"     "改革"     "共産"     "社会 [3]"
 [5,] "振興"     "無駄"     "党"       "国民 [✓]" "充実 [✓]"
 [6,] "作り"     "地域"     "選挙"     "税 [✓]"   "企業 [✓]"
 [7,] "道路"     "交代"     "金"       "政治 [3]" "守る"    
 [8,] "図る [5]" "廃止 [4]" "腐敗"     "増税"     "農業 [1]"
 [9,] "豊か"     "ひと [✓]" "主義 [✓]" "企業 [5]" "制度"    
[10,] "社会 [3]" "医療"     "税 [4]"   "反対"     "中小 [✓]"
[11,] "県"       "金"       "消費 [4]" "民主 [3]" "進める"  
[12,] "実現"     "実現"     "権"       "守る"     "年金"    
[13,] "道"       "税金"     "廃止 [4]" "選挙"     "生活"    
[14,] "建設"     "兆"       "企業 [5]" "つらぬく" "実現"    
[15,] "促進"     "民主 [3]" "日本 [4]" "基地"     "医療"    
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=15)
[[1]]


[[2]]

k = 30 (K+25)

Seeded LDA

model <- create_model(docs, seed_list, extra_k=25)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=T), seed_list)
      1          2          3          4          5         
 [1,] "整備 [✓]" "増税"     "政治 [✓]" "日本 [✓]" "企業 [✓]"
 [2,] "道路"     "党"       "税 [4]"   "党"       "教育 [✓]"
 [3,] "地域"     "憲法"     "消費 [4]" "国民 [✓]" "図る [✓]"
 [4,] "振興"     "守る"     "社会 [✓]" "共産"     "充実 [✓]"
 [5,] "道"       "円"       "自由 [✓]" "税 [✓]"   "中小 [✓]"
 [6,] "産業 [✓]" "自民党"   "国民 [4]" "政治 [3]" "対策"    
 [7,] "県"       "反対"     "自民党"   "消費 [✓]" "福祉"    
 [8,] "漁業 [✓]" "民営"     "廃止 [4]" "増税"     "制度"    
 [9,] "促進"     "税 [4]"   "金"       "民主 [3]" "確立"    
[10,] "建設"     "民主 [3]" "選挙"     "守る"     "安定"    
[11,] "早期"     "庶民"     "実現"     "反対"     "農業 [1]"
[12,] "交通"     "改悪"     "民主 [✓]" "選挙"     "実現"    
[13,] "農業 [✓]" "郵政"     "権"       "自民党"   "進める"  
[14,] "推進"     "日本 [4]" "党"       "主人公"   "つとめる"
[15,] "農林"     "野党"     "院"       "企業 [5]" "社会 [3]"
[[1]]


[[2]]


[[3]]

Standard LDA

get_lda_result(doc_folder, seed_list, iter_num, k=30)
[[1]]


[[2]]

New Visualization

# Heatmap of how the top words of each estimated topic are distributed over
# the true (simulated) topics.
#
# Args:
#   post:      posterior object accepted by top_terms().
#   n:         number of top words per estimated topic.
#   title_:    if TRUE, add the title "Seeded LDA: Top <n> words".
#   seed_list: optional list of seed keyword vectors; when given, only
#              estimated topics "EstTopic1" .. "EstTopic<length(seed_list)>"
#              are shown.
#   topicvec:  optional integer order of estimated topics on the axis; if
#              its length differs from the number of topics shown, a message
#              is emitted and 1..<number of topics> is used instead.
#   merge:     optional list of length-3 vectors c(a, b, c); true topics a
#              and b are relabelled as true topic c.
#
# The true topic is taken from the part of each word after "t" (split by
# tidyr::separate(sep = "t")). NOTE(review): assumes simulated words look
# like "<id>t<topic>" -- confirm against create_sim_data().
#
# Returns a ggplot object (Proportion = % of a topic's top words per true
# topic).
diagnosis_topic_recovery_heatmap <- function(post, n = 25, title_ = TRUE,
                                             seed_list = NULL,
                                             topicvec = c(), merge = list()) {
  topwords <- data.frame(top_terms(post, n = n))
  colnames(topwords) <- paste0("EstTopic", seq_len(ncol(topwords)))

  res_ <- tidyr::gather(topwords, key = EstTopic, value = Word) %>%
    mutate(Word = gsub("\\s.*$", "", Word)) %>%
    mutate(RawWord = Word) %>%
    tidyr::separate(Word, into = c("word_id", "TrueTopic"), sep = "t") %>%
    mutate(TrueTopic = paste0("True", as.character(TrueTopic)))

  # Merge topics: relabel true topics m[1] and m[2] as m[3].
  for (m in merge) {
    mt <- paste0("True", m)
    res_ <- res_ %>%
      mutate(TrueTopic = replace(TrueTopic, TrueTopic == mt[1], mt[3])) %>%
      mutate(TrueTopic = replace(TrueTopic, TrueTopic == mt[2], mt[3]))
  }

  # Percentage of each estimated topic's top words coming from each true topic.
  res_ <- res_ %>%
    group_by(EstTopic, TrueTopic) %>%
    summarise(counts = n()) %>%
    ungroup() %>%
    group_by(EstTopic) %>%
    mutate(topicsum = sum(counts)) %>%
    ungroup() %>%
    mutate(Proportion = counts / topicsum * 100)

  if (!is.null(seed_list)) {
    # Use only topics with keywords
    seed_list_name <- paste0("EstTopic", seq_along(seed_list))
    res_ <- filter(res_, EstTopic %in% seed_list_name)
  }

  num <- length(unique(res_$EstTopic))
  if (is.null(topicvec)) {
    # Default order: the estimated topics present, by numeric index.
    topicvec <- sort(unique(as.integer(gsub("EstTopic", "", res_$EstTopic))))
  } else if (length(topicvec) != num) {
    message("topicvec length does not match")
    topicvec <- seq_len(num)
  }

  # Show True1 .. True<largest observed id>. Counting distinct true topics
  # instead would silently drop tiles of higher-numbered topics whenever some
  # id is absent (e.g. after merging). Non-numeric labels are ignored here.
  true_ids <- suppressWarnings(as.integer(sub("^True", "", res_$TrueTopic)))
  truenum <- if (all(is.na(true_ids))) 0L else max(true_ids, na.rm = TRUE)

  title <- paste0("Seeded LDA: Top ", as.character(n), " words")

  g <- ggplot(res_, aes(EstTopic, TrueTopic)) +
    geom_tile(aes(fill = Proportion)) +
    scale_fill_gradient(limits = c(0, 100), low = "#e8e8e8", high = "#0072B2",
                        name = "Proportion") +
    scale_x_discrete(limits = rev(paste0("EstTopic", topicvec))) +
    coord_flip() +
    scale_y_discrete(limits = paste0("True", seq_len(truenum))) +
    xlab("Estimated Topics") + ylab("True Topic") + theme_bw(base_size = 13)

  if (title_) {
    g <- g + ggtitle(title) +
      theme(plot.title = element_text(hjust = 0.5))
  }

  g
}
library(grid)
library(gridExtra)
# Simulate data for every (trueK, estimatedK) pair, fit seeded LDA, and save
# the topic-recovery heatmap and the posterior as RDS files.
#
# Args:
#   trueK:              true numbers of topics to simulate.
#   estimatedK:         total numbers of topics to estimate.
#   seeds_len:          number of keywords per seeded topic.
#   seed_only:          if TRUE, the heatmap shows only the seeded topics.
#   seed_contamination: number of contamination rounds. Each round overwrites
#                       one random keyword of every topic with a random
#                       keyword of a random topic (possibly the same topic).
#   out_dir:            prefix for the saved files; it is pasted directly in
#                       front of the file name, so keep the trailing "/".
#                       Defaults to the original location.
#
# Writes "<out_dir>fig_T<trueK>_E<estimatedK>.obj" (ggplot) and
# "<out_dir>post_T<trueK>_E<estimatedK>.obj" (posterior) per pair.
# NOTE: the number of Gibbs iterations is read from the global `iter_num`.
run_simulations <- function(trueK, estimatedK, seeds_len = 6,
                            seed_only = FALSE, seed_contamination = 0,
                            out_dir = "/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/obj/") {
  # Create Combinations
  combinations <- expand.grid(trueK, estimatedK)
  num_combinations <- nrow(combinations)

  # Run Simulations
  for (s in seq_len(num_combinations)) {
    trueK_ <- combinations[s, 1]
    estimatedK_ <- combinations[s, 2]

    # Create Data
    set.seed(225)
    data_folder <- tempfile()
    seed_list <- create_sim_data(saveDir = paste0(data_folder, "Sim1"),
                                 D = 1000, K = trueK_, TotalV = 3000,
                                 alpha = 0.1, beta_r = 0.1, beta_s = 0.1,
                                 p = rep(0.15, trueK_) + rnorm(trueK_, mean = 0, sd = 0.04),
                                 lambda = 200, seeds_len = seeds_len)
    seed_list <- lapply(seed_list, tolower)

    # Seed contamination (seq_len(0) makes the default a no-op)
    for (i in seq_len(seed_contamination)) {
      seed_list <- lapply(seed_list, function(x) {
        x[sample(1:seeds_len, 1)] <- seed_list[[sample(1:trueK_, 1)]][sample(1:seeds_len, 1)]
        x
      })
    }

    # Fit the model; when fewer topics are estimated than seeded, keep only
    # the first `estimatedK_` seed topics.
    extra_k_ <- estimatedK_ - length(seed_list)
    if (extra_k_ < 0) {
      message("extra_k is negative, setting it to 0")
      extra_k_ <- 0
      seed_list <- seed_list[seq_len(estimatedK_)]
    }
    doc_folder <- paste0(data_folder, "Sim1", "/W")
    # `pattern` is a regular expression: match names ending in ".txt".
    docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
    model <- create_model(docs, seed_list, extra_k = extra_k_)
    res <- topicdict_train(model, iter = iter_num)
    post <- topicdict::posterior(res)

    heatmap_seeds <- if (seed_only) seed_list else NULL
    g <- diagnosis_topic_recovery_heatmap(post, 15, title_ = FALSE,
                                          seed_list = heatmap_seeds)

    # Save
    saveRDS(g, file = paste0(out_dir, "fig_T", trueK_, "_E", estimatedK_, ".obj"))
    saveRDS(post, file = paste0(out_dir, "post_T", trueK_, "_E", estimatedK_, ".obj"))

    message(paste0("Done: ", s, "/", num_combinations))
  }
}

# Combine the heatmaps saved by run_simulations() into one grid figure.
#
# Args:
#   trueK, estimatedK: the same vectors passed to run_simulations().
#   title_:            title drawn above the grid.
#   in_dir:            prefix of the saved "fig_T<trueK>_E<estimatedK>.obj"
#                      files (kept with its trailing "/"; defaults to the
#                      location run_simulations() writes to).
#
# Rows are estimatedK from largest (top) to smallest; within a row, panels
# follow the order of `trueK`. Draws the figure and returns the grob
# invisibly.
create_simulation_figure <- function(trueK, estimatedK, title_ = "Simulation Results",
                                     in_dir = "/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/obj/") {
  # Create Combinations; desc() orders rows correctly even when estimatedK
  # is not supplied in ascending order.
  combinations <- expand.grid(trueK, estimatedK) %>%
    arrange(desc(Var2))
  num_combinations <- nrow(combinations)

  # Load Data
  figures <- lapply(seq_len(num_combinations), function(s) {
    readRDS(file = paste0(in_dir, "fig_T", combinations[s, 1],
                          "_E", combinations[s, 2], ".obj"))
  })

  ### Create Figure
  # Reuse the first panel's legend as one shared legend. A gtable can hold
  # several "guide" slots, some of them empty (zeroGrob), so take the first
  # non-empty one instead of indexing with every match.
  g1 <- ggplotGrob(figures[[1]])
  guide_grobs <- g1$grobs[grep("guide", g1$layout$name)]
  guide_grobs <- Filter(function(x) !inherits(x, "zeroGrob"), guide_grobs)
  legend <- if (length(guide_grobs) > 0) guide_grobs[[1]] else NULL

  # Edit Figure: strip per-panel legends and axes
  edit_figure <- theme(legend.position = "none",
                       axis.title.x = element_blank(),
                       axis.title.y = element_blank(),
                       axis.text.x = element_blank(),
                       axis.text.y = element_blank())
  figures <- lapply(figures, function(x) x + edit_figure)

  # New Pictures cf. https://stackoverflow.com/a/11093069/4357279
  g <- arrangeGrob(grobs = figures,
                   nrow = length(estimatedK),
                   right = legend,
                   top = textGrob(title_),
                   left = textGrob("Estimated Topic", rot = 90, vjust = 1),
                   bottom = textGrob("True Topic", vjust = -0.1))

  grid.draw(g) # Show plot
  invisible(g)
}

# Run the whole simulation grid, then draw the combined heatmap figure.
#
# Args:
#   trueK, estimatedK:  vectors of true / estimated numbers of topics; every
#                       combination is simulated.
#   seed_only:          forwarded to run_simulations().
#   seed_contamination: forwarded to run_simulations().
multiple_simulations <- function(trueK,
                                 estimatedK,
                                 seed_only = FALSE,
                                 seed_contamination = 0) {
  run_simulations(trueK, estimatedK,
                  seed_only = seed_only,
                  seed_contamination = seed_contamination)
  create_simulation_figure(trueK, estimatedK)
}
# How many "true" topics can the keywords collect? Every combination of true
# and estimated numbers of topics, with uncontaminated seed keywords.
multiple_simulations(trueK=c(5,15,25,35), estimatedK=c(5,15,25,35))

[1] "Finished: "
   user  system elapsed 
 10.108   1.706  11.849 
[1] "Finished: "
   user  system elapsed 
  8.461   0.996   9.494 
[1] "Finished: "
   user  system elapsed 
  7.844   1.054   8.928 
[1] "Finished: "
   user  system elapsed 
  7.460   0.914   8.393 
[1] "Finished: "
   user  system elapsed 
 10.134   1.666  11.820 
[1] "Finished: "
   user  system elapsed 
  7.605   1.090   8.710 
[1] "Finished: "
   user  system elapsed 
  7.720   1.006   8.743 
[1] "Finished: "
   user  system elapsed 
  7.849   1.024   8.891 
[1] "Finished: "
   user  system elapsed 
 10.019   1.698  11.738 
[1] "Finished: "
   user  system elapsed 
  8.187   1.069   9.278 
[1] "Finished: "
   user  system elapsed 
  7.873   1.053   8.953 
[1] "Finished: "
   user  system elapsed 
  7.443   0.848   8.312 
[1] "Finished: "
   user  system elapsed 
 10.060   1.699  11.780 
[1] "Finished: "
   user  system elapsed 
  8.053   1.181   9.252 
[1] "Finished: "
   user  system elapsed 
  7.707   0.921   8.650 
[1] "Finished: "
   user  system elapsed 
  7.707   1.026   8.752 
# Same grid with seed_contamination = 2: two rounds in which one randomly
# chosen keyword of every seeded topic is overwritten by a keyword drawn from
# a random topic (the draw may come from the same topic or hit the same slot).
multiple_simulations(trueK=c(5,15,25,35), estimatedK=c(5,15,25,35),
                     seed_contamination=2)

[1] "Finished: "
   user  system elapsed 
 10.447   1.823  12.291 
[1] "Finished: "
   user  system elapsed 
  8.428   1.089   9.535 
[1] "Finished: "
   user  system elapsed 
  7.633   0.932   8.587 
[1] "Finished: "
   user  system elapsed 
  7.663   0.951   8.630 
[1] "Finished: "
   user  system elapsed 
 10.217   1.781  12.019 
[1] "Finished: "
   user  system elapsed 
  8.067   1.106   9.195 
[1] "Finished: "
   user  system elapsed 
  7.929   1.091   9.037 
[1] "Finished: "
   user  system elapsed 
  7.391   1.046   8.458 
[1] "Finished: "
   user  system elapsed 
 10.192   1.758  11.970 
[1] "Finished: "
   user  system elapsed 
  8.415   1.135   9.568 
[1] "Finished: "
   user  system elapsed 
  7.730   1.080   8.832 
[1] "Finished: "
   user  system elapsed 
  7.512   1.067   8.597 
[1] "Finished: "
   user  system elapsed 
 10.083   1.825  11.933 
[1] "Finished: "
   user  system elapsed 
  7.948   1.191   9.157 
[1] "Finished: "
   user  system elapsed 
  7.586   1.004   8.607 
[1] "Finished: "
   user  system elapsed 
  7.423   1.019   8.459